{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='1'></a>\n",
    "# 1. Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential, Model\n",
    "from keras.layers import *\n",
    "from keras.layers.advanced_activations import LeakyReLU\n",
    "from keras.activations import relu\n",
    "from keras.initializers import RandomNormal\n",
    "from keras.applications import *\n",
    "import keras.backend as K\n",
    "from tensorflow.contrib.distributions import Beta\n",
    "import tensorflow as tf\n",
    "from keras.optimizers import Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from image_augmentation import random_transform\n",
    "from image_augmentation import random_warp\n",
    "from utils import get_image_paths, load_images, stack_images\n",
    "from pixel_shuffler import PixelShuffler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import cv2\n",
    "import glob\n",
    "from random import randint, shuffle\n",
    "from IPython.display import clear_output\n",
    "from IPython.display import display\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='4'></a>\n",
    "# 4. Config\n",
    "\n",
    "mixup paper: https://arxiv.org/abs/1710.09412\n",
    "\n",
    "Default training data directories: `./faceA/` and `./faceB/`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook only runs the generators for inference, so the Keras\n",
    "# learning phase is pinned to 0 (test mode) before any model is built.\n",
    "K.set_learning_phase(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "channel_axis=-1\n",
    "channel_first = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SHAPE = (64, 64, 3)\n",
    "nc_in = 3 # number of input channels of generators\n",
    "nc_D_inp = 6 # number of input channels of discriminators\n",
    "\n",
    "use_self_attn = False\n",
    "w_l2 = 1e-4 # weight decay\n",
    "\n",
    "# Kernel initializer referenced by the model-building helpers\n",
    "# (conv_block, conv_block_d, res_block, self_attn_block, Encoder, Discriminator).\n",
    "# It has to exist before any of them is called, otherwise a fresh kernel\n",
    "# fails with NameError. The value mirrors the RandomNormal(0, 0.02) that\n",
    "# upscale_ps() passes inline.\n",
    "conv_init = RandomNormal(0, 0.02)\n",
    "\n",
    "batchSize = 8\n",
    "\n",
    "# Path of training images\n",
    "img_dirA = './faceA/*.*'\n",
    "img_dirB = './faceB/*.*'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='5'></a>\n",
    "# 5. Define models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Scale(Layer):\n",
    "    '''\n",
    "    Layer that multiplies its input by one learnable scalar, `gamma`.\n",
    "\n",
    "    Code borrows from https://github.com/flyyufelix/cnn_finetune\n",
    "\n",
    "    # Arguments\n",
    "        weights: optional list of initial weight values; applied with\n",
    "            `set_weights` at the end of `build`.\n",
    "        axis: stored and reported by `get_config`; `call` does not use it.\n",
    "        gamma_init: initializer identifier for `gamma` (default 'zero').\n",
    "    '''\n",
    "    def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs):\n",
    "        self.axis = axis\n",
    "        self.gamma_init = initializers.get(gamma_init)\n",
    "        self.initial_weights = weights\n",
    "        super(Scale, self).__init__(**kwargs)\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        '''Create the shape-(1,) `gamma` variable and apply any initial weights.'''\n",
    "        self.input_spec = [InputSpec(shape=input_shape)]\n",
    "\n",
    "        # Compatibility with TensorFlow >= 1.0.0\n",
    "        self.gamma = K.variable(self.gamma_init((1,)), name='{}_gamma'.format(self.name))\n",
    "        self.trainable_weights = [self.gamma]\n",
    "\n",
    "        if self.initial_weights is not None:\n",
    "            self.set_weights(self.initial_weights)\n",
    "            del self.initial_weights\n",
    "\n",
    "    def call(self, x, mask=None):\n",
    "        '''Return `gamma * x`; `mask` is accepted but ignored.'''\n",
    "        return self.gamma * x\n",
    "\n",
    "    def get_config(self):\n",
    "        '''Return the base layer config extended with `axis` and `gamma_init`.'''\n",
    "        # gamma_init is serialized as well, so a layer rebuilt from this config\n",
    "        # keeps its initializer instead of falling back to the 'zero' default.\n",
    "        config = {\"axis\": self.axis,\n",
    "                  \"gamma_init\": initializers.serialize(self.gamma_init)}\n",
    "        base_config = super(Scale, self).get_config()\n",
    "        return dict(list(base_config.items()) + list(config.items()))\n",
    "\n",
    "\n",
    "def self_attn_block(inp, nc):\n",
    "    '''\n",
    "    Self-attention over the spatial positions of `inp`, added back to `inp`.\n",
    "\n",
    "    Code borrows from https://github.com/taki0112/Self-Attention-GAN-Tensorflow\n",
    "\n",
    "    inp: feature-map tensor; its static shape (minus the batch axis) is used\n",
    "         to reshape the attention output before the residual add.\n",
    "    nc:  channel count of the value projection; nc//8 channels are used for\n",
    "         the two projections that form the attention scores.\n",
    "    Returns the tensor `Scale()(attention) + inp`.\n",
    "    '''\n",
    "    assert nc//8 > 0, f\"Input channels must be >= 8, but got nc={nc}\"\n",
    "    reduced_nc = nc // 8\n",
    "    inp_shape = inp.get_shape().as_list()\n",
    "\n",
    "    # 1x1 projections; created in the order f, g, h.\n",
    "    f = Conv2D(reduced_nc, 1, kernel_initializer=conv_init)(inp)\n",
    "    g = Conv2D(reduced_nc, 1, kernel_initializer=conv_init)(inp)\n",
    "    h = Conv2D(nc, 1, kernel_initializer=conv_init)(inp)\n",
    "\n",
    "    # Collapse every non-batch axis except the channel axis into one.\n",
    "    f_channels = f.get_shape().as_list()[-1]\n",
    "    g_channels = g.get_shape().as_list()[-1]\n",
    "    h_channels = h.get_shape().as_list()[-1]\n",
    "    flat_f = Reshape((-1, f_channels))(f)\n",
    "    flat_g = Reshape((-1, g_channels))(g)\n",
    "    flat_h = Reshape((-1, h_channels))(h)\n",
    "\n",
    "    # scores = flat_g . flat_f^T, normalised with a softmax over the last axis.\n",
    "    scores = Lambda(lambda pair: tf.matmul(pair[0], pair[1], transpose_b=True))([flat_g, flat_f])\n",
    "    attn = Softmax(axis=-1)(scores)\n",
    "\n",
    "    attended = Lambda(lambda pair: tf.matmul(pair[0], pair[1]))([attn, flat_h])\n",
    "    attended = Reshape(inp_shape[1:])(attended)\n",
    "    attended = Scale()(attended)\n",
    "\n",
    "    return add([attended, inp])\n",
    "\n",
    "def conv_block(input_tensor, f):\n",
    "    '''Stride-2 3x3 convolution with `f` filters (no bias, L2 penalty w_l2), then ReLU.'''\n",
    "    downsample = Conv2D(\n",
    "        f,\n",
    "        kernel_size=3,\n",
    "        strides=2,\n",
    "        padding=\"same\",\n",
    "        use_bias=False,\n",
    "        kernel_initializer=conv_init,\n",
    "        kernel_regularizer=regularizers.l2(w_l2),\n",
    "    )\n",
    "    features = downsample(input_tensor)\n",
    "    return Activation(\"relu\")(features)\n",
    "\n",
    "def conv_block_d(input_tensor, f, use_instance_norm=False):\n",
    "    '''\n",
    "    Stride-2 4x4 convolution with `f` filters (no bias, L2 penalty w_l2),\n",
    "    followed by LeakyReLU(0.2).\n",
    "\n",
    "    `use_instance_norm` is accepted but has no effect in this implementation.\n",
    "    '''\n",
    "    downsample = Conv2D(\n",
    "        f,\n",
    "        kernel_size=4,\n",
    "        strides=2,\n",
    "        padding=\"same\",\n",
    "        use_bias=False,\n",
    "        kernel_initializer=conv_init,\n",
    "        kernel_regularizer=regularizers.l2(w_l2),\n",
    "    )\n",
    "    features = downsample(input_tensor)\n",
    "    return LeakyReLU(alpha=0.2)(features)\n",
    "\n",
    "def res_block(input_tensor, f):\n",
    "    '''\n",
    "    Residual block: conv -> LeakyReLU -> conv, added to `input_tensor`,\n",
    "    then a final LeakyReLU. Both convs are 3x3 with `f` filters and no bias.\n",
    "    '''\n",
    "    def _conv3x3():\n",
    "        # A fresh layer (and regularizer) per call, as in the two inline convs.\n",
    "        return Conv2D(f, kernel_size=3, padding=\"same\", use_bias=False,\n",
    "                      kernel_initializer=conv_init,\n",
    "                      kernel_regularizer=regularizers.l2(w_l2))\n",
    "\n",
    "    residual = _conv3x3()(input_tensor)\n",
    "    residual = LeakyReLU(alpha=0.2)(residual)\n",
    "    residual = _conv3x3()(residual)\n",
    "    merged = add([residual, input_tensor])\n",
    "    return LeakyReLU(alpha=0.2)(merged)\n",
    "\n",
    "def upscale_ps(filters, use_norm=True):\n",
    "    '''\n",
    "    Return a function that applies conv (4*filters channels) -> LeakyReLU(0.2)\n",
    "    -> PixelShuffler to a tensor.\n",
    "\n",
    "    `use_norm` is accepted but has no effect in this implementation.\n",
    "    NOTE(review): PixelShuffler is imported from pixel_shuffler; presumably it\n",
    "    doubles height/width and divides channels by 4 (Encoder's 4x4x1024 tensor\n",
    "    feeds Decoder_ps's 8x8x512 input) -- confirm against that module.\n",
    "    '''\n",
    "    def block(x):\n",
    "        expand = Conv2D(\n",
    "            filters*4,\n",
    "            kernel_size=3,\n",
    "            padding='same',\n",
    "            kernel_initializer=RandomNormal(0, 0.02),\n",
    "            kernel_regularizer=regularizers.l2(w_l2),\n",
    "        )\n",
    "        expanded = expand(x)\n",
    "        activated = LeakyReLU(0.2)(expanded)\n",
    "        return PixelShuffler()(activated)\n",
    "    return block\n",
    "\n",
    "def Discriminator(nc_in, input_size=64):\n",
    "    '''\n",
    "    Build the discriminator model.\n",
    "\n",
    "    Three conv_block_d stages (64, 128, 256 filters) followed by a 1-filter\n",
    "    4x4 convolution. When `use_self_attn` is set, a self-attention block is\n",
    "    inserted after the 128- and 256-filter stages.\n",
    "    '''\n",
    "    inp = Input(shape=(input_size, input_size, nc_in))\n",
    "    features = conv_block_d(inp, 64, False)\n",
    "    for n_filters in (128, 256):\n",
    "        features = conv_block_d(features, n_filters, False)\n",
    "        if use_self_attn:\n",
    "            features = self_attn_block(features, n_filters)\n",
    "    score_map = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding=\"same\")(features)\n",
    "    return Model(inputs=[inp], outputs=score_map)\n",
    "\n",
    "def Encoder(nc_in=3, input_size=64):\n",
    "    '''\n",
    "    Build the encoder model.\n",
    "\n",
    "    A 5x5 stem conv, four conv_block stages (128, 256, 512, 1024 filters),\n",
    "    two Dense layers, a reshape to (4, 4, 1024) and one upscale_ps(512).\n",
    "    When `use_self_attn` is set, self-attention follows the 256- and\n",
    "    512-filter stages.\n",
    "    '''\n",
    "    inp = Input(shape=(input_size, input_size, nc_in))\n",
    "    features = Conv2D(64, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding=\"same\")(inp)\n",
    "\n",
    "    # (filters, whether a self-attention block may follow this stage)\n",
    "    stages = ((128, False), (256, True), (512, True), (1024, False))\n",
    "    for n_filters, attn_slot in stages:\n",
    "        features = conv_block(features, n_filters)\n",
    "        if attn_slot and use_self_attn:\n",
    "            features = self_attn_block(features, n_filters)\n",
    "\n",
    "    bottleneck = Dense(1024)\n",
    "    latent = bottleneck(Flatten()(features))\n",
    "    latent = Dense(4*4*1024)(latent)\n",
    "    latent = Reshape((4, 4, 1024))(latent)\n",
    "    return Model(inputs=inp, outputs=upscale_ps(512)(latent))\n",
    "\n",
    "def Decoder_ps(nc_in=512, input_size=8):\n",
    "    '''\n",
    "    Build the decoder model.\n",
    "\n",
    "    Three upscale_ps stages (256, 128, 64 filters) and a res_block, then two\n",
    "    5x5 heads: a 1-channel sigmoid mask and a 3-channel tanh image. The model\n",
    "    output is the concatenation [mask, image], mask channel first.\n",
    "    When `use_self_attn` is set, self-attention follows the 128-filter\n",
    "    upscale and the res_block.\n",
    "    '''\n",
    "    input_ = Input(shape=(input_size, input_size, nc_in))\n",
    "    features = upscale_ps(256)(input_)\n",
    "    features = upscale_ps(128)(features)\n",
    "    if use_self_attn:\n",
    "        features = self_attn_block(features, 128)\n",
    "    features = upscale_ps(64)(features)\n",
    "    features = res_block(features, 64)\n",
    "    if use_self_attn:\n",
    "        features = self_attn_block(features, 64)\n",
    "\n",
    "    mask_head = Conv2D(1, kernel_size=5, padding='same', activation=\"sigmoid\")(features)\n",
    "    image_head = Conv2D(3, kernel_size=5, padding='same', activation=\"tanh\")(features)\n",
    "    return Model(input_, concatenate([mask_head, image_head]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One encoder and two decoders, all built with their default arguments.\n",
    "encoder = Encoder()\n",
    "decoder_A = Decoder_ps()\n",
    "decoder_B = Decoder_ps()\n",
    "\n",
    "x = Input(shape=IMAGE_SHAPE)\n",
    "\n",
    "# Both generators share the same `encoder` instance and the same input\n",
    "# tensor `x`; they differ only in the decoder applied to the encoding.\n",
    "netGA = Model(x, decoder_A(encoder(x)))\n",
    "netGB = Model(x, decoder_B(encoder(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discriminators with nc_D_inp input channels. No later cell in this\n",
    "# notebook uses them; their weight loading below is commented out.\n",
    "netDA = Discriminator(nc_D_inp)\n",
    "netDB = Discriminator(nc_D_inp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='6'></a>\n",
    "# 6. Load Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model loaded.\n"
     ]
    }
   ],
   "source": [
    "# Load the generator weights. Only a missing or unreadable file is tolerated\n",
    "# here; any other failure (for example a mismatch between the saved weights\n",
    "# and the architecture defined above) propagates instead of being reported\n",
    "# as a missing file.\n",
    "try:\n",
    "    encoder.load_weights(\"models/encoder.h5\")\n",
    "    decoder_A.load_weights(\"models/decoder_A.h5\")\n",
    "    decoder_B.load_weights(\"models/decoder_B.h5\")\n",
    "    #netDA.load_weights(\"models/netDA.h5\") \n",
    "    #netDB.load_weights(\"models/netDB.h5\") \n",
    "    print (\"model loaded.\")\n",
    "except OSError as err:\n",
    "    print (\"Weights file not found.\")\n",
    "    # Show which file could not be opened.\n",
    "    print (err)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='7'></a>\n",
    "# 7. Define Inputs/Outputs Variables\n",
    "\n",
    "    distorted_A: A (batch_size, 64, 64, 3) tensor, input of generator_A (netGA).\n",
    "    distorted_B: A (batch_size, 64, 64, 3) tensor, input of generator_B (netGB).\n",
    "    fake_A: (batch_size, 64, 64, 4) tensor, raw output of generator_A (netGA): 1 mask channel followed by 3 colour channels (rgb_A).\n",
    "    fake_B: (batch_size, 64, 64, 4) tensor, raw output of generator_B (netGB): 1 mask channel followed by 3 colour channels (rgb_B).\n",
    "    mask_A: (batch_size, 64, 64, 1) tensor, mask output of generator_A (netGA), i.e. the first channel of fake_A.\n",
    "    mask_B: (batch_size, 64, 64, 1) tensor, mask output of generator_B (netGB), i.e. the first channel of fake_B.\n",
    "    path_A: A function that takes distorted_A as input and outputs the masked image mask_A * rgb_A + (1 - mask_A) * distorted_A.\n",
    "    path_B: A function that takes distorted_B as input and outputs the masked image mask_B * rgb_B + (1 - mask_B) * distorted_B.\n",
    "    path_mask_A: A function that takes distorted_A as input and outputs mask_A repeated over 3 channels.\n",
    "    path_mask_B: A function that takes distorted_B as input and outputs mask_B repeated over 3 channels.\n",
    "    path_abgr_A: A function that takes distorted_A as input and outputs concat([mask_A, rgb_A]).\n",
    "    path_abgr_B: A function that takes distorted_B as input and outputs concat([mask_B, rgb_B]).\n",
    "    real_A: A (batch_size, 64, 64, 3) tensor, target images for generator_A given input distorted_A.\n",
    "    real_B: A (batch_size, 64, 64, 3) tensor, target images for generator_B given input distorted_B."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cycle_variables(netG):\n",
    "    '''\n",
    "    Expose the tensors of generator `netG` and wrap them in backend functions.\n",
    "\n",
    "    Returns, in order: the generator's input tensor, its raw output tensor,\n",
    "    the mask (first output channel), and three K.function objects mapping\n",
    "    the input to (1) the mask-blended image, (2) the mask concatenated with\n",
    "    itself three times, (3) the mask concatenated with the remaining channels.\n",
    "    '''\n",
    "    net_in = netG.inputs[0]\n",
    "    net_out = netG.outputs[0]\n",
    "\n",
    "    # Channel 0 is the mask; every channel after it is colour.\n",
    "    mask = Lambda(lambda t: t[:,:,:, :1])(net_out)\n",
    "    colour = Lambda(lambda t: t[:,:,:, 1:])(net_out)\n",
    "\n",
    "    # Generated pixels where the mask is high, the input image elsewhere.\n",
    "    blended = mask * colour + (1-mask) * net_in\n",
    "\n",
    "    fn_generate = K.function([net_in], [blended])\n",
    "    fn_mask = K.function([net_in], [concatenate([mask, mask, mask])])\n",
    "    fn_abgr = K.function([net_in], [concatenate([mask, colour])])\n",
    "    return net_in, net_out, mask, fn_generate, fn_mask, fn_abgr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack, per generator: input tensor, raw output tensor, mask tensor and\n",
    "# the three backend functions returned by cycle_variables().\n",
    "distorted_A, fake_A, mask_A, path_A, path_mask_A, path_abgr_A = cycle_variables(netGA)\n",
    "distorted_B, fake_B, mask_B, path_B, path_mask_B, path_abgr_B = cycle_variables(netGB)\n",
    "# Input placeholders with IMAGE_SHAPE; no later cell in this notebook uses them.\n",
    "real_A = Input(shape=IMAGE_SHAPE)\n",
    "real_B = Input(shape=IMAGE_SHAPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='11'></a>\n",
    "# 11. Helper Function: swap_face()\n",
    "This function is provided for those who don't have enough VRAM to run dlib's CNN and GAN model at the same time.\n",
    "\n",
    "    INPUTS:\n",
    "        img: An RGB face image of any size.\n",
    "        path_func: a function that is either path_abgr_A or path_abgr_B.\n",
    "    OUTPUTS:\n",
    "        result_img: An RGB swapped face image after masking.\n",
    "        result_mask: A single-channel mask image with values in the 0-255 range (floating point; cast to uint8 if required)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def swap_face(img, path_func, ae_size=64):\n",
    "    '''\n",
    "    Run a generator on a face image and alpha-blend its output onto the image.\n",
    "\n",
    "    INPUTS:\n",
    "        img: RGB image array of any size. Values are treated as 0-255\n",
    "            (they are divided by 255 before being fed to the generator).\n",
    "        path_func: either path_abgr_A or path_abgr_B. It is called with a\n",
    "            batch holding one BGR image scaled to [-1, 1]; channel 0 of its\n",
    "            output is read as the mask, the remaining channels as BGR.\n",
    "        ae_size: side length the image is resized to before it is passed to\n",
    "            path_func (default 64, the size used by IMAGE_SHAPE).\n",
    "    OUTPUTS:\n",
    "        result: uint8 RGB image with the same height and width as `img`.\n",
    "        result_a: mask of shape (height, width, 1), floating point,\n",
    "            scaled to the 0-255 range.\n",
    "    '''\n",
    "    height, width = img.shape[:2]\n",
    "    full_size = (width, height)  # cv2.resize expects (width, height)\n",
    "\n",
    "    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # generator expects BGR input\n",
    "    ae_input = cv2.resize(img, (ae_size, ae_size))/255. * 2 - 1\n",
    "\n",
    "    result = np.squeeze(np.array([path_func([[ae_input]])]))\n",
    "\n",
    "    # Mask: channel 0, rescaled to 0-255 and resized to the input resolution.\n",
    "    result_a = result[:,:,0] * 255\n",
    "    result_a = cv2.resize(result_a, full_size)[...,np.newaxis]\n",
    "\n",
    "    # Colour: remaining channels mapped from [-1, 1] back to 0-255.\n",
    "    result_bgr = np.clip( (result[:,:,1:] + 1) * 255 / 2, 0, 255)\n",
    "    result_bgr = cv2.resize(result_bgr, full_size)\n",
    "\n",
    "    blend = result_a / 255\n",
    "    result = (blend * result_bgr + (1 - blend) * img).astype('uint8')\n",
    "    result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    # Both outputs are already at the input resolution, so the extra\n",
    "    # same-size resizes the function used to perform here were no-ops.\n",
    "    return result, result_a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "whom2whom = \"BtoA\" # default: transforming faceB to faceA\n",
    "\n",
    "# Compare by value: `is` tests object identity and only appears to work for\n",
    "# string literals because of interning.\n",
    "if whom2whom == \"AtoB\":\n",
    "    path_func = path_abgr_B\n",
    "elif whom2whom == \"BtoA\":\n",
    "    path_func = path_abgr_A\n",
    "else:\n",
    "    # Fail here rather than leave path_func undefined (or stale from an\n",
    "    # earlier run) for the cells below.\n",
    "    raise ValueError(\"whom2whom should be either AtoB or BtoA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_img = plt.imread(\"./IMAGE_FILENAME.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(input_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_img, result_mask = swap_face(input_img, path_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(result_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(result_mask[:, :, 0]) # cmap='gray'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
